{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import GPT2LMHeadModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_hf = GPT2LMHeadModel.from_pretrained(\"gpt2\") # 124M\n",
    "sd_hf = model_hf.state_dict()\n",
    "\n",
    "for k, v in sd_hf.items():\n",
    "    print(k, v.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sd_hf[\"transformer.wpe.weight\"].view(-1)[:20]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "plt.imshow(sd_hf[\"transformer.wpe.weight\"], cmap=\"gray\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(sd_hf[\"transformer.wpe.weight\"][:, 150])\n",
    "plt.plot(sd_hf[\"transformer.wpe.weight\"][:, 200])\n",
    "plt.plot(sd_hf[\"transformer.wpe.weight\"][:, 250])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(sd_hf[\"transformer.h.1.attn.c_attn.weight\"][:300,:300], cmap=\"gray\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import pipeline, set_seed\n",
    "generator = pipeline('text-generation', model='gpt2')\n",
    "set_seed(42)\n",
    "generator(\"Hello, I'm a language model,\", max_length=30, num_return_sequences=5)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# let's instead sample manually\n",
    "import torch\n",
    "from torch.nn import functional as F\n",
    "\n",
    "model = GPT2LMHeadModel.from_pretrained(\"gpt2\") # 124M\n",
    "model.eval()\n",
    "model.to('cuda')\n",
    "torch.manual_seed(42)\n",
    "torch.cuda.manual_seed(42)\n",
    "tokens = [15496, 11, 314, 1101, 257, 3303, 2746, 11] # \"Hello, I'm a language model,\"\n",
    "tokens = torch.tensor(tokens, dtype=torch.long) # (8,)\n",
    "tokens = tokens.unsqueeze(0).repeat(5, 1) # (5, 8)\n",
    "x = tokens.to('cuda')\n",
    "\n",
    "# generate!\n",
    "while x.size(1) < 30: # max_length=30\n",
    "    # forward the model to get the logits\n",
    "    with torch.no_grad():\n",
    "        logits = model(x)[0] # (B, T, vocab_size)\n",
    "        # take the logits at the last position\n",
    "        logits = logits[:, -1, :] # (B, vocab_size)\n",
    "        # get the probabilities\n",
    "        probs = F.softmax(logits, dim=-1)\n",
    "        # do top-k sampling of 50 (huggingface pipeline default)\n",
    "        # topk_probs here becomes (5, 50), topk_indices is (5, 50)\n",
    "        topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)\n",
    "        # select a token from the top-k probabilities\n",
    "        # note: multinomial does not demand the input to sum to 1\n",
    "        ix = torch.multinomial(topk_probs, 1) # (B, 1)\n",
    "        # gather the corresponding indices\n",
    "        xcol = torch.gather(topk_indices, -1, ix) # (B, 1)\n",
    "        # append to the sequence\n",
    "        x = torch.cat((x, xcol), dim=1)\n",
    "\n",
    "# print the generated text\n",
    "import tiktoken\n",
    "enc = tiktoken.get_encoding('gpt2')\n",
    "for i in range(5):\n",
    "    tokens = x[i, :30].tolist()\n",
    "    decoded = enc.decode(tokens)\n",
    "    print(\">\", decoded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tiny shakespeare dataset\n",
    "# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\n",
    "with open('input.txt', 'r') as f:\n",
    "    text = f.read()\n",
    "data = text[:1000] # first 1,000 characters\n",
    "print(data[:100])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tiktoken\n",
    "enc = tiktoken.get_encoding('gpt2')\n",
    "tokens = enc.encode(data)\n",
    "print(tokens[:24])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "buf = torch.tensor(tokens[:24 + 1])\n",
    "x = buf[:-1].view(4, 6)\n",
    "y = buf[1:].view(4, 6)\n",
    "print(x)\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(sd_hf[\"lm_head.weight\"].shape)\n",
    "print(sd_hf[\"transformer.wte.weight\"].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(sd_hf[\"lm_head.weight\"] == sd_hf[\"transformer.wte.weight\"]).all()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(sd_hf[\"lm_head.weight\"].data_ptr())\n",
    "print(sd_hf[\"transformer.wte.weight\"].data_ptr())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# standard deviation grows inside the residual stream\n",
    "x = torch.zeros(768)\n",
    "n = 100 # e.g. 100 layers\n",
    "for i in range(n):\n",
    "    x += n**-0.5 * torch.randn(768)\n",
    "\n",
    "print(x.std())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# super simple little MLP\n",
    "net = torch.nn.Sequential(\n",
    "    torch.nn.Linear(16, 32),\n",
    "    torch.nn.GELU(),\n",
    "    torch.nn.Linear(32, 1)\n",
    ")\n",
    "torch.random.manual_seed(42)\n",
    "x = torch.randn(4, 16)\n",
    "y = torch.randn(4, 1)\n",
    "net.zero_grad()\n",
    "yhat = net(x)\n",
    "loss = torch.nn.functional.mse_loss(yhat, y)\n",
    "loss.backward()\n",
    "print(net[0].weight.grad.view(-1)[:10])\n",
    "\n",
    "# the loss objective here is (due to readuction='mean')\n",
    "# L = 1/4 * [\n",
    "#            (y[0] - yhat[0])**2 +\n",
    "#            (y[1] - yhat[1])**2 +\n",
    "#            (y[2] - yhat[2])**2 +\n",
    "#            (y[3] - yhat[3])**2\n",
    "#           ]\n",
    "# NOTE: 1/4!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# now let's do it with grad_accum_steps of 4, and B=1\n",
    "# the loss objective here is different because\n",
    "# accumulation in gradient <---> SUM in loss\n",
    "# i.e. we instead get:\n",
    "# L0 = 1/4(y[0] - yhat[0])**2\n",
    "# L1 = 1/4(y[1] - yhat[1])**2\n",
    "# L2 = 1/4(y[2] - yhat[2])**2\n",
    "# L3 = 1/4(y[3] - yhat[3])**2\n",
    "# L = L0 + L1 + L2 + L3\n",
    "# NOTE: the \"normalizer\" of 1/4 is lost\n",
    "net.zero_grad()\n",
    "for i in range(4):\n",
    "    yhat = net(x[i])\n",
    "    loss = torch.nn.functional.mse_loss(yhat, y[i])\n",
    "    loss = loss / 4 # <-- have to add back the \"normalizer\"!\n",
    "    loss.backward()\n",
    "print(net[0].weight.grad.view(-1)[:10])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# parse and visualize the logfile\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "sz = \"124M\"\n",
    "\n",
    "loss_baseline = {\n",
    "    \"124M\": 3.2924,\n",
    "}[sz]\n",
    "hella2_baseline = { # HellaSwag for GPT-2\n",
    "    \"124M\": 0.294463,\n",
    "    \"350M\": 0.375224,\n",
    "    \"774M\": 0.431986,\n",
    "    \"1558M\": 0.488946,\n",
    "}[sz]\n",
    "hella3_baseline = { # HellaSwag for GPT-3\n",
    "    \"124M\": 0.337,\n",
    "    \"350M\": 0.436,\n",
    "    \"774M\": 0.510,\n",
    "    \"1558M\": 0.547,\n",
    "}[sz]\n",
    "\n",
    "# load the log file\n",
    "with open(\"log124M_40B/log.txt\", \"r\") as f:\n",
    "    lines = f.readlines()\n",
    "\n",
    "# parse the individual lines, group by stream (train,val,hella)\n",
    "streams = {}\n",
    "for line in lines:\n",
    "    step, stream, val = line.strip().split()\n",
    "    if stream not in streams:\n",
    "        streams[stream] = {}\n",
    "    streams[stream][int(step)] = float(val)\n",
    "\n",
    "# convert each stream from {step: val} to (steps[], vals[])\n",
    "# so it's easier for plotting\n",
    "streams_xy = {}\n",
    "for k, v in streams.items():\n",
    "    # get all (step, val) items, sort them\n",
    "    xy = sorted(list(v.items()))\n",
    "    # unpack the list of tuples to tuple of lists\n",
    "    streams_xy[k] = list(zip(*xy))\n",
    "\n",
    "# create figure\n",
    "plt.figure(figsize=(16, 6))\n",
    "\n",
    "# Panel 1: losses: both train and val\n",
    "plt.subplot(121)\n",
    "xs, ys = streams_xy[\"train\"] # training loss\n",
    "ys = np.array(ys)\n",
    "plt.plot(xs, ys, label=f'nanogpt ({sz}) train loss')\n",
    "print(\"Min Train Loss:\", min(ys))\n",
    "xs, ys = streams_xy[\"val\"] # validation loss\n",
    "plt.plot(xs, ys, label=f'nanogpt ({sz}) val loss')\n",
    "# horizontal line at GPT-2 baseline\n",
    "if loss_baseline is not None:\n",
    "    plt.axhline(y=loss_baseline, color='r', linestyle='--', label=f\"OpenAI GPT-2 ({sz}) checkpoint val loss\")\n",
    "plt.xlabel(\"steps\")\n",
    "plt.ylabel(\"loss\")\n",
    "plt.yscale('log')\n",
    "plt.ylim(top=4.0)\n",
    "plt.legend()\n",
    "plt.title(\"Loss\")\n",
    "print(\"Min Validation Loss:\", min(ys))\n",
    "\n",
    "# Panel 2: HellaSwag eval\n",
    "plt.subplot(122)\n",
    "xs, ys = streams_xy[\"hella\"] # HellaSwag eval\n",
    "ys = np.array(ys)\n",
    "plt.plot(xs, ys, label=f\"nanogpt ({sz})\")\n",
    "# horizontal line at GPT-2 baseline\n",
    "if hella2_baseline:\n",
    "    plt.axhline(y=hella2_baseline, color='r', linestyle='--', label=f\"OpenAI GPT-2 ({sz}) checkpoint\")\n",
    "if hella3_baseline:\n",
    "    plt.axhline(y=hella3_baseline, color='g', linestyle='--', label=f\"OpenAI GPT-3 ({sz}) checkpoint\")\n",
    "plt.xlabel(\"steps\")\n",
    "plt.ylabel(\"accuracy\")\n",
    "plt.legend()\n",
    "plt.title(\"HellaSwag eval\")\n",
    "print(\"Max Hellaswag eval:\", max(ys))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
